#!/usr/bin/env python

#####################################################################
#
# queuestat is a tool for summarizing queue statistics of all ports. 
#
#####################################################################

import argparse
import cPickle as pickle
import datetime
import os.path
import swsssdk
import sys

from collections import namedtuple, OrderedDict
from natsort import natsorted
from tabulate import tabulate


# One record per queue: identity (index, type) followed by the four counters.
QueueStats = namedtuple(
    "QueueStats",
    ["queueindex", "queuetype", "totalpacket", "totalbytes", "droppacket", "dropbytes"],
)

# Column titles of the printed table.
header = ['Port', 'TxQ', 'Counter/pkts', 'Counter/bytes', 'Drop/pkts', 'Drop/bytes']

# Counter name read from the per-queue counters table -> position of the
# matching field inside a QueueStats record (positions 0 and 1 hold the
# queue index and queue type).
counter_bucket_dict = {
    'SAI_QUEUE_STAT_PACKETS': QueueStats._fields.index('totalpacket'),
    'SAI_QUEUE_STAT_BYTES': QueueStats._fields.index('totalbytes'),
    'SAI_QUEUE_STAT_DROPPED_PACKETS': QueueStats._fields.index('droppacket'),
    'SAI_QUEUE_STAT_DROPPED_BYTES': QueueStats._fields.index('dropbytes'),
}

# Placeholder values shown instead of a counter.
STATUS_NA = 'N/A'
STATUS_INVALID = 'INVALID'

# Queue type values as stored in COUNTERS_QUEUE_TYPE_MAP ...
SAI_QUEUE_TYPE_MULTICAST = "SAI_QUEUE_TYPE_MULTICAST"
SAI_QUEUE_TYPE_UNICAST = "SAI_QUEUE_TYPE_UNICAST"
SAI_QUEUE_TYPE_ALL = "SAI_QUEUE_TYPE_ALL"
# ... and the short labels they are displayed as.
QUEUE_TYPE_MC = 'MC'
QUEUE_TYPE_UC = 'UC'
QUEUE_TYPE_ALL = 'ALL'

# Key prefix and lookup-table names used when querying the counters database.
COUNTER_TABLE_PREFIX = "COUNTERS:"
COUNTERS_PORT_NAME_MAP = "COUNTERS_PORT_NAME_MAP"
COUNTERS_QUEUE_NAME_MAP = "COUNTERS_QUEUE_NAME_MAP"
COUNTERS_QUEUE_TYPE_MAP = "COUNTERS_QUEUE_TYPE_MAP"
COUNTERS_QUEUE_INDEX_MAP = "COUNTERS_QUEUE_INDEX_MAP"
COUNTERS_QUEUE_PORT_MAP = "COUNTERS_QUEUE_PORT_MAP"

# Cache directory and per-port cache file prefix; both are overwritten by
# main() before they are used.
cnstat_dir = cnstat_fqn_file = 'N/A'

class Queuestat(object):
    """
        Read per-queue counters from COUNTERS_DB and print, diff or cache them.

        Attributes built by the constructor:
          counter_port_name_map  port name -> port table id
          port_name_map          port table id -> port name
          port_queues_map        port name -> {queue name: queue table id}
    """

    def __init__(self):
        """
            Connect to COUNTERS_DB and group every queue under its port.

            Exits the process with status 1 when the port or queue name map
            is unavailable or a queue has no owning port recorded.
        """
        self.db = swsssdk.SonicV2Connector(host='127.0.0.1')
        self.db.connect(self.db.COUNTERS_DB)

        def get_queue_port(table_id):
            """
                Return the table id of the port owning queue `table_id`.
            """
            port_table_id = self.db.get(self.db.COUNTERS_DB, COUNTERS_QUEUE_PORT_MAP, table_id)
            if port_table_id is None:
                print("Port is not available! %s" % (table_id,))
                sys.exit(1)

            return port_table_id

        # Get all ports
        self.counter_port_name_map = self.db.get_all(self.db.COUNTERS_DB, COUNTERS_PORT_NAME_MAP)
        if self.counter_port_name_map is None:
            print("COUNTERS_PORT_NAME_MAP is empty!")
            sys.exit(1)

        self.port_queues_map = {}
        self.port_name_map = {}

        for port in self.counter_port_name_map:
            self.port_queues_map[port] = {}
            # Reverse lookup, needed to attach each queue to its port below.
            self.port_name_map[self.counter_port_name_map[port]] = port

        # Get Queues for each port
        counter_queue_name_map = self.db.get_all(self.db.COUNTERS_DB, COUNTERS_QUEUE_NAME_MAP)
        if counter_queue_name_map is None:
            print("COUNTERS_QUEUE_NAME_MAP is empty!")
            sys.exit(1)

        for queue in counter_queue_name_map:
            port = self.port_name_map[get_queue_port(counter_queue_name_map[queue])]
            self.port_queues_map[port][queue] = counter_queue_name_map[queue]

    def get_cnstat(self, queue_map):
        """
            Get the counters info from database.

            queue_map maps queue name -> queue table id (or is None).
            Returns an OrderedDict whose 'time' entry holds the collection
            time, followed by one QueueStats per queue in natural-sort order
            of the queue names. Exits with status 1 when a queue index or
            type is missing or the type is not recognised.
        """
        # Stored queue type -> short label used in the TxQ column.
        queue_type_labels = {
            SAI_QUEUE_TYPE_MULTICAST: QUEUE_TYPE_MC,
            SAI_QUEUE_TYPE_UNICAST: QUEUE_TYPE_UC,
            SAI_QUEUE_TYPE_ALL: QUEUE_TYPE_ALL,
        }

        def get_queue_index(table_id):
            """
                Return the queue index recorded for `table_id`.
            """
            queue_index = self.db.get(self.db.COUNTERS_DB, COUNTERS_QUEUE_INDEX_MAP, table_id)
            if queue_index is None:
                print("Queue index is not available! %s" % (table_id,))
                sys.exit(1)

            return queue_index

        def get_queue_type(table_id):
            """
                Return the short label (MC/UC/ALL) for the type of `table_id`.
            """
            queue_type = self.db.get(self.db.COUNTERS_DB, COUNTERS_QUEUE_TYPE_MAP, table_id)
            if queue_type is None:
                print("Queue Type is not available! %s" % (table_id,))
                sys.exit(1)
            if queue_type not in queue_type_labels:
                print("Queue Type is invalid: %s %s" % (table_id, queue_type))
                sys.exit(1)

            return queue_type_labels[queue_type]

        def get_counters(table_id):
            """
                Get the counters from specific table.
            """
            fields = ["0"] * len(QueueStats._fields)
            fields[0] = get_queue_index(table_id)
            fields[1] = get_queue_type(table_id)

            # The table key does not depend on the counter being read.
            full_table_id = COUNTER_TABLE_PREFIX + table_id
            for counter_name, pos in counter_bucket_dict.items():
                counter_data = self.db.get(self.db.COUNTERS_DB, full_table_id, counter_name)
                if counter_data is None:
                    fields[pos] = STATUS_NA
                else:
                    # Round-trip through int() to normalise the stored text.
                    fields[pos] = str(int(counter_data))
            return QueueStats._make(fields)

        # Build a dictionary of the stats
        cnstat_dict = OrderedDict()
        cnstat_dict['time'] = datetime.datetime.now()
        if queue_map is None:
            return cnstat_dict
        for queue in natsorted(queue_map):
            cnstat_dict[queue] = get_counters(queue_map[queue])
        return cnstat_dict

    def cnstat_print(self, port, cnstat_dict):
        """
            Print the cnstat.
        """
        table = []

        for key, data in cnstat_dict.items():
            if key == 'time':
                continue
            table.append((port, data.queuetype + str(data.queueindex),
                        data.totalpacket, data.totalbytes,
                        data.droppacket, data.dropbytes))

        print(tabulate(table, header, tablefmt='simple', stralign='right'))
        print("")

    def cnstat_diff_print(self, port, cnstat_new_dict, cnstat_old_dict):
        """
            Print the difference between two cnstat results.

            Queues that are missing from cnstat_old_dict are printed with
            their current absolute values.
        """
        def ns_diff(newstr, oldstr):
            """
                Calculate the diff.

                Returns STATUS_NA when either side is STATUS_NA, otherwise
                the difference formatted with thousands separators.
            """
            if newstr == STATUS_NA or oldstr == STATUS_NA:
                return STATUS_NA
            return '{:,}'.format(int(newstr) - int(oldstr))

        table = []

        for key, cntr in cnstat_new_dict.items():
            if key == 'time':
                continue
            old_cntr = cnstat_old_dict.get(key)

            if old_cntr is not None:
                table.append((port, cntr.queuetype + str(cntr.queueindex),
                            ns_diff(cntr.totalpacket, old_cntr.totalpacket),
                            ns_diff(cntr.totalbytes, old_cntr.totalbytes),
                            ns_diff(cntr.droppacket, old_cntr.droppacket),
                            ns_diff(cntr.dropbytes, old_cntr.dropbytes)))
            else:
                table.append((port, cntr.queuetype + str(cntr.queueindex),
                        cntr.totalpacket, cntr.totalbytes,
                        cntr.droppacket, cntr.dropbytes))

        print(tabulate(table, header, tablefmt='simple', stralign='right'))
        print("")

    def _print_port_stat(self, port, time_prefix):
        """
            Print the stats of one port, as a diff if a cache file exists.

            time_prefix is prepended to the "Last cached time was" line.
        """
        cnstat_dict = self.get_cnstat(self.port_queues_map[port])

        cnstat_fqn_file_name = cnstat_fqn_file + port
        if not os.path.isfile(cnstat_fqn_file_name):
            self.cnstat_print(port, cnstat_dict)
            return

        # NOTE(security): unpickling executes whatever the file describes.
        # The cache lives under a predictable /tmp path (set in main()), so
        # it is only trustworthy while that directory belongs to this user.
        try:
            with open(cnstat_fqn_file_name, 'rb') as cache_file:
                cnstat_cached_dict = pickle.load(cache_file)
        except EnvironmentError as e:
            print("%s %s" % (e.errno, e))
            return

        print(time_prefix + "Last cached time was " + str(cnstat_cached_dict.get('time')))
        self.cnstat_diff_print(port, cnstat_dict, cnstat_cached_dict)

    def get_print_all_stat(self):
        """
            Print the stats of every port in natural-sort order.
        """
        # Get stat for each port
        for port in natsorted(self.counter_port_name_map):
            self._print_port_stat(port, port + " ")

    def get_print_port_stat(self, port):
        """
            Print the stats of a single port; exit with status 1 if unknown.
        """
        if port not in self.port_queues_map:
            print("Port doesn't exist! %s" % (port,))
            sys.exit(1)

        # Get stat for the port queried
        self._print_port_stat(port, "")

    def save_fresh_stats(self):
        """
            Save the current counters of every port as the new baseline.

            Creates the cache directory when needed and writes one pickle
            file per port. Exits with a non-zero status on any I/O failure.
        """
        if not os.path.exists(cnstat_dir):
            try:
                os.makedirs(cnstat_dir)
            except EnvironmentError as e:
                # os.makedirs raises OSError, which `except IOError` misses.
                print("%s %s" % (e.errno, e))
                sys.exit(1)

        # Get stat for each port and save
        for port in natsorted(self.counter_port_name_map):
            cnstat_dict = self.get_cnstat(self.port_queues_map[port])
            try:
                with open(cnstat_fqn_file + port, 'wb') as cache_file:
                    pickle.dump(cnstat_dict, cache_file)
            except IOError as e:
                print("%s %s" % (e.errno, e))
                sys.exit(e.errno)
            else:
                print("Clear and update saved counters for " + port)

def main():
    """
        Parse the command line and show, save or delete queue statistics.

        Sets the module-level cnstat_dir / cnstat_fqn_file to the per-user
        cache location under /tmp, then:
          -d  removes every cached file and the cache directory,
          -c  stores the current counters as the new baseline,
          -p  prints the stats of one port,
          otherwise prints the stats of all ports.
        Always terminates through sys.exit().
    """
    global cnstat_dir
    global cnstat_fqn_file

    parser = argparse.ArgumentParser(description='Display the queue state and counters',
                                     formatter_class=argparse.RawTextHelpFormatter,
                                     epilog="""
Examples:
  queuestat
  queuestat -p Ethernet0
  queuestat -c
  queuestat -d
""")

    # Same -v/--version option the deprecated ArgumentParser(version=...)
    # keyword used to register.
    parser.add_argument('-v', '--version', action='version', version='1.0.0',
                        help="show program's version number and exit")
    parser.add_argument('-p', '--port', type=str, help='Show the queue counters for just one port', default=None)
    parser.add_argument('-c', '--clear', action='store_true', help='Clear previous stats and save new ones')
    parser.add_argument('-d', '--delete', action='store_true', help='Delete saved stats')
    args = parser.parse_args()

    save_fresh_stats = args.clear
    delete_all_stats = args.delete

    port_to_show_stats = args.port

    # The cache is kept per user: /tmp/queuestat-<uid>/<uid><port>
    uid = str(os.getuid())
    cnstat_file = uid

    cnstat_dir = "/tmp/queuestat-" + uid
    cnstat_fqn_file = cnstat_dir + "/" + cnstat_file

    if delete_all_stats:
        # listdir/remove/rmdir raise OSError (e.g. when nothing was saved
        # yet); report it instead of crashing with a traceback.
        try:
            for entry in os.listdir(cnstat_dir):
                os.remove(os.path.join(cnstat_dir, entry))
            os.rmdir(cnstat_dir)
        except EnvironmentError as e:
            print("%s %s" % (e.errno, e))
            sys.exit(1)
        sys.exit(0)

    queuestat = Queuestat()

    if save_fresh_stats:
        queuestat.save_fresh_stats()
        sys.exit(0)

    if port_to_show_stats is not None:
        queuestat.get_print_port_stat(port_to_show_stats)
    else:
        queuestat.get_print_all_stat()

    sys.exit(0)

# Run the command-line tool only when this file is executed as a script.
if __name__ == "__main__":
    main()
